{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "C1rAsD2L-hSO"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6f8f3af-744e-4eaa-8a30-6d03e8e4d21e"
      },
      "source": [
        "# Apache Beam RunInference for scikit-learn\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_sklearn.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_sklearn.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "This notebook demonstrates the use of the RunInference transform for [scikit-learn](https://scikit-learn.org/), also called sklearn.\n",
        "Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) has implementations of the [ModelHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelHandler) class prebuilt for scikit-learn. For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
        "\n",
        "You can choose the appropriate model handler based on your input data type:\n",
        "* [NumPy model handler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.sklearn_inference.html#apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy)\n",
        "* [Pandas DataFrame model handler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.sklearn_inference.html#apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerPandas)\n",
        "\n",
        "With RunInference, these model handlers manage batching, vectorization, and prediction optimization for your scikit-learn pipeline or model.\n",
        "\n",
        "This notebook demonstrates the following common RunInference patterns:\n",
        "*   Generate predictions.\n",
        "*   Postprocess results after RunInference.\n",
        "*   Run inference with multiple models in the same pipeline.\n",
        "\n",
        "The linear regression models used in these samples are trained on data that corresponds to the 5 and 10 times tables; that is, `y = 5x` and `y = 10x`, respectively."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zzwnMzzgdyPB"
      },
      "source": [
        "## Before you begin\n",
        "Complete the following setup steps:\n",
        "1. Install dependencies for Apache Beam.\n",
        "1. Authenticate with Google Cloud.\n",
        "1. Specify your project and bucket. You use the project and bucket to save and load models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6vlKcT-Wev20",
        "outputId": "336e8afc-6716-41dd-a438-500353189c62"
      },
      "outputs": [],
      "source": [
        "!pip install google-api-core --quiet\n",
        "!pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet\n",
        "!pip install apache-beam[gcp,dataframe] --quiet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32c9ba40-9396-48f4-9e7f-a2acced98bb2"
      },
      "source": [
        "## About scikit-learn versions\n",
        "\n",
        "`scikit-learn` is a build dependency of Apache Beam. If you need to install a different version of sklearn, use `%pip install scikit-learn==<version>`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "V0E35R5Ka2cE"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "d6142b75-eef1-4e52-9fa4-fe02fe916b26"
      },
      "outputs": [],
      "source": [
        "import pickle\n",
        "from sklearn import linear_model\n",
        "from typing import Tuple\n",
        "\n",
        "import numpy as np\n",
        "import apache_beam as beam\n",
        "\n",
        "from apache_beam.ml.inference.sklearn_inference import ModelFileType\n",
        "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n",
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "\n",
        "# NOTE: If an error occurs, restart your runtime.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "248458a6-cfd8-474d-ad0e-f37f7ae981ae"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Constants\n",
        "project = \"<PROJECT_ID>\" # @param {type:'string'}\n",
        "bucket = \"<BUCKET_NAME>\" # @param {type:'string'}\n",
        "\n",
        "# To avoid warnings, set the project.\n",
        "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6695cd22-e0bf-438f-8223-4a93392e6616"
      },
      "source": [
        "## Create the data and the scikit-learn model\n",
        "This section demonstrates the following steps:\n",
        "1. Create the data to train the scikit-learn linear regression model.\n",
        "2. Train the linear regression model.\n",
        "3. Save the scikit-learn model using `pickle`.\n",
        "\n",
        "In this example, you create two models: one for the 5 times table and one for the 10 times table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "c57843e8-f696-4196-ad39-827e34849976"
      },
      "outputs": [],
      "source": [
        "# Input data to train the sklearn model for the 5 times table.\n",
        "x = np.arange(0, 100, dtype=np.float32).reshape(-1, 1)\n",
        "y = (x * 5).reshape(-1, 1)\n",
        "\n",
        "def train_and_save_model(x, y, model_file_name):\n",
        "  # Fit a linear regression on (x, y) and pickle the trained model to disk.\n",
        "  regression = linear_model.LinearRegression()\n",
        "  regression.fit(x, y)\n",
        "\n",
        "  with open(model_file_name, 'wb') as f:\n",
        "    pickle.dump(regression, f)\n",
        "\n",
        "five_times_model_filename = 'sklearn_5x_model.pkl'\n",
        "train_and_save_model(x, y, five_times_model_filename)\n",
        "\n",
        "# Change y to be 10 times x, and train a model for the 10 times table.\n",
        "ten_times_model_filename = 'sklearn_10x_model.pkl'\n",
        "y = (x * 10).reshape(-1, 1)\n",
        "train_and_save_model(x, y, ten_times_model_filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "69008a3d-3d15-4643-828c-b0419b347d01"
      },
      "source": [
        "### Create a scikit-learn RunInference pipeline\n",
        "This section demonstrates how to do the following:\n",
        "1. Define a scikit-learn model handler that accepts an `array_like` object as input.\n",
        "2. Read the data from BigQuery.\n",
        "3. Use the scikit-learn trained model and the scikit-learn RunInference transform on unkeyed data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "AEGaqpMVqgRP"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade google-cloud-bigquery --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xq5AKtRrqlUx",
        "outputId": "fba8fb42-4958-451a-8aaa-9a838052a2f8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Updated property [core/project].\n"
          ]
        }
      ],
      "source": [
        "!gcloud config set project $project"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QCIjN__rpoVF",
        "outputId": "0ded224f-2272-482e-80f5-bb2d21b6f5d8"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<google.cloud.bigquery.table._EmptyRowIterator at 0x7f97abb4e850>"
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Populate the BigQuery table with sample data.\n",
        "\n",
        "from google.cloud import bigquery\n",
        "\n",
        "client = bigquery.Client(project=project)\n",
        "\n",
        "# Make sure the dataset_id is unique in your project.\n",
        "dataset_id = '{project}.maths'.format(project=project)\n",
        "dataset = bigquery.Dataset(dataset_id)\n",
        "\n",
        "# Modify the location based on your project configuration.\n",
        "dataset.location = 'US'\n",
        "dataset = client.create_dataset(dataset, exists_ok=True)\n",
        "\n",
        "# Table name in the BigQuery dataset.\n",
        "table_name = 'maths_problems_1'\n",
        "\n",
        "query = \"\"\"\n",
        "    CREATE OR REPLACE TABLE\n",
        "      {project}.maths.{table} ( key STRING OPTIONS(description=\"A unique key for the maths problem\"),\n",
        "    value FLOAT64 OPTIONS(description=\"Our maths problem\" ) );\n",
        "    INSERT INTO maths.{table}\n",
        "    VALUES\n",
        "      (\"first_example\", 105.00),\n",
        "      (\"second_example\", 108.00),\n",
        "      (\"third_example\", 1000.00),\n",
        "      (\"fourth_example\", 1013.00)\n",
        "\"\"\".format(project=project, table=table_name)\n",
        "\n",
        "create_job = client.query(query)\n",
        "create_job.result()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "50a648a3-794a-4286-ab2b-fc0458db04ca",
        "outputId": "8eab34b4-dcc7-4df1-ec0e-8c86a34d31c6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "PredictionResult(example=[1000.0], inference=array([5000.]))\n",
            "PredictionResult(example=[1013.0], inference=array([5065.]))\n",
            "PredictionResult(example=[108.0], inference=array([540.]))\n",
            "PredictionResult(example=[105.0], inference=array([525.]))\n"
          ]
        }
      ],
      "source": [
        "sklearn_model_handler = SklearnModelHandlerNumpy(model_uri=five_times_model_filename) \n",
        "\n",
        "\n",
        "pipeline_options = PipelineOptions().from_dictionary(\n",
        "                                      {'temp_location':f'gs://{bucket}/tmp'})\n",
        "\n",
        "# Define the BigQuery table specification.\n",
        "table_name = 'maths_problems_1'\n",
        "table_spec = f'{project}:maths.{table_name}'\n",
        "\n",
        "with beam.Pipeline(options=pipeline_options) as p:\n",
        "  (\n",
        "      p \n",
        "      | \"ReadFromBQ\" >> beam.io.ReadFromBigQuery(table=table_spec)\n",
        "      | \"ExtractInputs\" >> beam.Map(lambda x: [x['value']]) \n",
        "      | \"RunInferenceSklearn\" >> RunInference(model_handler=sklearn_model_handler)\n",
        "      | beam.Map(print)\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "33e901d6-ed06-4268-8a5f-685d31b5558f"
      },
      "source": [
        "### Use sklearn RunInference on keyed inputs\n",
        "This section demonstrates how to do the following:\n",
        "1. Wrap the `SklearnModelHandlerNumpy` object in a `KeyedModelHandler` to handle keyed data.\n",
        "2. Read the data from BigQuery.\n",
        "3. Use the sklearn trained model and the sklearn RunInference transform on keyed data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c212916d-b517-4589-ad15-a3a1df926fb3",
        "outputId": "61db2d76-4dfa-4b38-cf9a-645790b4c5aa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "('third_example', PredictionResult(example=[1000.0], inference=array([5000.])))\n",
            "('fourth_example', PredictionResult(example=[1013.0], inference=array([5065.])))\n",
            "('second_example', PredictionResult(example=[108.0], inference=array([540.])))\n",
            "('first_example', PredictionResult(example=[105.0], inference=array([525.])))\n"
          ]
        }
      ],
      "source": [
        "sklearn_model_handler = SklearnModelHandlerNumpy(model_uri=five_times_model_filename) \n",
        "keyed_sklearn_model_handler = KeyedModelHandler(sklearn_model_handler)\n",
        "\n",
        "pipeline_options = PipelineOptions().from_dictionary(\n",
        "                                      {'temp_location':f'gs://{bucket}/tmp'})\n",
        "with beam.Pipeline(options=pipeline_options) as p:\n",
        "  (\n",
        "  p \n",
        "  | \"ReadFromBQ\" >> beam.io.ReadFromBigQuery(table=table_spec)\n",
        "  | \"ExtractInputs\" >> beam.Map(lambda x: (x['key'], [x['value']])) \n",
        "  | \"RunInferenceSklearn\" >> RunInference(model_handler=keyed_sklearn_model_handler)\n",
        "  | beam.Map(print)\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JQ4zvlwsRK1W"
      },
      "source": [
        "## Run multiple models\n",
        "\n",
        "This code creates a pipeline that applies two RunInference transforms, each using a different model, to the same input, and then combines their outputs into a single collection."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 86,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0qMlX6SeR68D",
        "outputId": "5e4a0852-3761-47da-aa08-0386fd524a78"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key = third_example * 10, example = 1000.0 -> predictions 10000.0\n",
            "key = fourth_example * 10, example = 1013.0 -> predictions 10130.0\n",
            "key = second_example * 10, example = 108.0 -> predictions 1080.0\n",
            "key = first_example * 10, example = 105.0 -> predictions 1050.0\n",
            "key = third_example * 5, example = 1000.0 -> predictions 5000.0\n",
            "key = fourth_example * 5, example = 1013.0 -> predictions 5065.0\n",
            "key = second_example * 5, example = 108.0 -> predictions 540.0\n",
            "key = first_example * 5, example = 105.0 -> predictions 525.0\n"
          ]
        }
      ],
      "source": [
        "def format_output(run_inference_output) -> str:\n",
        "  \"\"\"Takes input from RunInference for scikit-learn and extracts the output.\"\"\"\n",
        "  key, prediction_result = run_inference_output\n",
        "  example = prediction_result.example[0]\n",
        "  prediction = prediction_result.inference[0]\n",
        "  return f\"key = {key}, example = {example} -> predictions {prediction}\"\n",
        "\n",
        "five_times_model_handler = KeyedModelHandler(\n",
        "    SklearnModelHandlerNumpy(model_uri=five_times_model_filename))\n",
        "ten_times_model_handler = KeyedModelHandler(\n",
        "    SklearnModelHandlerNumpy(model_uri=ten_times_model_filename))\n",
        "\n",
        "pipeline_options = PipelineOptions().from_dictionary(\n",
        "                                      {'temp_location':f'gs://{bucket}/tmp'})\n",
        "with beam.Pipeline(options=pipeline_options) as p:\n",
        "  inputs = (p \n",
        "    | \"ReadFromBQ\" >> beam.io.ReadFromBigQuery(table=table_spec))\n",
        "  five_times = (inputs\n",
        "    | \"Extract For 5\" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 5'), [x['value']]))\n",
        "    | \"5 times\" >> RunInference(model_handler = five_times_model_handler))\n",
        "  ten_times = (inputs\n",
        "    | \"Extract For 10\" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 10'), [x['value']]))\n",
        "    | \"10 times\" >> RunInference(model_handler = ten_times_model_handler))\n",
        "  _ = ((five_times, ten_times) | \"Flattened\" >> beam.Flatten()\n",
        "    | \"format output\" >> beam.Map(format_output)\n",
        "    | \"Print\" >> beam.Map(print))\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
